import os
import re
import jieba
from numpy import *
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model


# Predefined constants
MAX_SEQUENCE_LENGTH = 100    # maximum padded sequence length passed to pad_sequences

def readFlie(path):
    """Read one sample record and return its full text (one file == one sample).

    The file is decoded as UTF-8 and undecodable bytes are silently dropped,
    matching the original ``errors='ignore'`` behaviour but making the
    encoding explicit instead of locale-dependent.

    NOTE: the misspelled name "Flie" is kept because other functions in this
    file call it by this name.
    """
    # 'with' closes the file automatically; the old explicit close() inside
    # the with-block was redundant and has been removed.
    with open(path, 'r', encoding='utf-8', errors='ignore') as file:
        return file.read()

def getStopWord(inputFile):
    """Load the stop-word file at *inputFile* and return its lines as a list."""
    return readFlie(inputFile).splitlines()

# Compiled once at import time instead of on every call: matches any character
# that is NOT an ASCII letter, digit, or CJK ideograph (U+4E00-U+9FA5).
_NON_WORD_RE = re.compile(u"[^a-zA-Z0-9\u4E00-\u9FA5]")

def remove_punctuation(line):
    """Strip every character except ASCII letters, digits and Chinese
    characters from *line*.

    The argument is coerced with ``str()`` first, so non-string input is
    accepted. Whitespace-only input yields ''.
    """
    line = str(line)
    if line.strip() == '':
        return ''
    return _NON_WORD_RE.sub('', line)

def predict(text, model, stopWordList, tokenizer=None):
    """Predict the category index of *text* using the trained *model*.

    Args:
        text: raw input string to classify.
        model: loaded Keras model expecting input length MAX_SEQUENCE_LENGTH.
        stopWordList: path of the stop-word file.
        tokenizer: the Tokenizer fitted on the TRAINING corpus. When None
            (legacy behaviour) a fresh Tokenizer is fitted on the input
            text itself — NOTE(review): that assigns word indices unrelated
            to the vocabulary the model was trained with, so predictions
            are unreliable; pass the training tokenizer whenever possible.

    Returns:
        Array of argmax class indices (one per input row).
    """
    stopwords = getStopWord(stopWordList)
    # Clean the text, segment it with jieba, and drop stop words.
    txt = remove_punctuation(text)
    txt = [" ".join(w for w in jieba.cut(txt) if w not in stopwords)]
    if tokenizer is None:
        # Legacy path kept for backward compatibility (see NOTE above).
        tokenizer = Tokenizer()
        tokenizer.fit_on_texts(txt)
    print(txt)
    seq = tokenizer.texts_to_sequences(txt)
    padded = pad_sequences(seq, maxlen=MAX_SEQUENCE_LENGTH)
    print(seq)
    pred = model.predict(padded)
    print(pred)
    cat_id = pred.argmax(axis=1)
    return cat_id

if __name__ == '__main__':
    # Load the trained CNN classifier and run a single demo prediction.
    cnn_model = load_model('cnn.h5')
    stop_word_file = "./stop/stopword.txt"  # path to the stop-word list
    sample_text = "人民银行长沙中心支行快速响应 做好疫情期间金融服务"
    print(predict(sample_text, cnn_model, stop_word_file))